#!/usr/bin/env python3
# -*- coding: utf-8 -*-
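"""
Plot fire heat flux from WRF fire-model output (wrfout_d03 files) for three
fuel-load scenarios (_FUELx8, _FUELx4, _FUELx1): plan-view heat-flux maps plus
an averaged, normalized heat-flux cross section with fire-line width
estimates. One PNG is saved per output file.

Created on Tue Jan 25 21:35:12 2022

@author: matthewroberts
"""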

import os
import sys
import numpy as np
import glob

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.gridspec as gridspec
import matplotlib.colors as colors
import matplotlib.cm as cm

import cartopy.crs as crs
import cartopy.io.img_tiles as cimgt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import pyproj
import pyproj
from pyproj import Geod

import datetime as dt
import wrf
from netCDF4 import Dataset
import pandas as pd

import warnings
# Silence every warning for the whole run.
# NOTE(review): this also hides matplotlib/cartopy warnings about ignored
# keyword arguments and deprecations — consider narrowing the filter.
warnings.filterwarnings('ignore')

# Wall-clock start, reported again at the end of the script.
start_time = dt.datetime.now()
print(f'\nProgram Started: {start_time}')
print('===================')

# Creates shaded relief background on maps
from cartopy.io.img_tiles import GoogleTiles
class ShadedReliefESRI(GoogleTiles):
    """Tile source for the ESRI World Shaded Relief basemap.

    Inherits cartopy's ``GoogleTiles`` tiling machinery and overrides only
    the tile URL, so tiles are requested from ArcGIS Online.
    """

    def _image_url(self, tile):
        """Return the download URL for the ``(x, y, z)`` tile."""
        col, row, zoom = tile
        return ('https://server.arcgisonline.com/ArcGIS/rest/services/'
                f'World_Shaded_Relief/MapServer/tile/{zoom}/{row}/{col}.jpg')
    
def planarWind(uVal,vVal,bearing):
    """Return the component of the horizontal wind along a cross-section bearing.

    All angles are in degrees, measured clockwise from north on 0-360.
    Assuming ``uVal`` is the eastward and ``vVal`` the northward component,
    ``arctan2(u, v)`` is the direction the wind vector points toward. The
    result is ``speed * cos(bearing - direction)``: positive when the wind
    blows along the bearing and negative when it opposes it. The function
    works element-wise on scalars or numpy arrays.
    """
    toward_deg = np.rad2deg(np.arctan2(uVal, vVal))
    angle_between = bearing - toward_deg
    speed = np.sqrt(uVal**2 + vVal**2)
    return speed * np.cos(np.deg2rad(angle_between))

def relax_zone_remover (input, sr):
    """Trim ``sr`` trailing rows and columns from a 2-D field.

    In this script it strips the relaxation zone from the end of both axes
    of the fire-mesh arrays (level-set function, heat flux, fuel category,
    fire-mesh lat/lon).

    Parameters
    ----------
    input : array_like, at least 2-D
        Field to trim. It is sliced rather than rebuilt, so numpy masked
        arrays keep their mask and labelled arrays keep consistent
        coordinates.
    sr : int
        Number of trailing rows and columns to remove. Values <= 0 return
        ``input`` itself, unchanged.

    Returns
    -------
    A copy of ``input[:-sr, :-sr]`` (trailing axes untouched).

    Raises
    ------
    IndexError
        If ``sr`` exceeds the number of rows or columns.
    """
    if sr <= 0:
        return input
    n_rows, n_cols = np.shape(input)[:2]
    if sr > n_rows or sr > n_cols:
        raise IndexError('cannot remove {0} rows/columns from an array of '
                         'shape {1}'.format(sr, np.shape(input)))
    # One slice + copy replaces 2*sr np.delete calls, each of which copied
    # the whole array.
    return input[:n_rows - sr, :n_cols - sr].copy()

#########################
# Inputs
#########################

# True: average the heat-flux cross section over a fan of parallel transects.
AvgXsect = True

# keyword = 'bear_fire' #for naming plots
keyword = 'caldor_fire'  # selects the case directory and names the plots

mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
filepath = f'{mainpath}/data/{keyword}/_CONTROL'  # control-run wrfout files
savepath = f'{mainpath}/plots/{keyword}/flame_depth'  # figure output directory
filelist_og = sorted(glob.glob(filepath + '/wrfout_d03*'))

# Keep bare file names only: each name is re-opened under every
# fuel-scenario directory in the main loop.
filelist = [os.path.basename(path) for path in filelist_og]

# testing
# filelist = [filelist[-15]]

# Parent directory of the per-scenario runs. _FUELx8 must stay first: its
# peak heat flux is the normalization for every scenario.
filepath = f'{mainpath}/data/{keyword}'
fuels = ['_FUELx8', '_FUELx4', '_FUELx1']

from scipy import interpolate

# Cross-section end points as (lat, lon) pairs, NE end first, and the
# plan-view map extent [lon_min, lon_max, lat_min, lat_max] for each case.
XSECT_ENDPOINTS = {
    'bear_fire': ((39.85, -121.0), (39.5, -121.4)),
    'caldor_fire': ((38.83, -120.3), (38.5, -120.68)),
}
MAP_EXTENTS = {
    'bear_fire': [-121.31, -121.05, 39.58, 39.77],
    'caldor_fire': [-120.58, -120.35, 38.58, 38.74],
}
if keyword not in XSECT_ENDPOINTS:
    raise ValueError('No cross-section end points / map extent defined for '
                     'keyword ' + repr(keyword))
extent = MAP_EXTENTS[keyword]

# Lambert conformal parameters shared by the coarse-grid projection and
# ll_to_xy_proj.
# NOTE(review): these constants are used for every keyword; confirm they
# describe the caldor_fire domain too, not only bear_fire.
TRUELAT = 39.71475
STAND_LON = -121.1799

# Grid-index offsets of the parallel transects that get averaged. The x and
# y offsets move in opposite directions, i.e. each transect is shifted by
# (+k, -k) in index space.
if AvgXsect:
    x_idx = np.arange(-3, 4, 1)
    y_idx = np.arange(-3, 4, 1)[::-1]
else:
    # Single transect through the configured end points.
    x_idx = np.array([0])
    y_idx = np.array([0])

# Loop-invariant resources, created once instead of per panel / per file.
esri_terrain = ShadedReliefESRI()
reader = shpreader.Reader('/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp')
COUNTIES = cfeature.ShapelyFeature(list(reader.geometries()), ccrs.PlateCarree())
geod = Geod(ellps='WGS84')


def _coarsen(field, coarseness, valid_only=False):
    """Block-average a 2-D field onto a grid ``coarseness`` times coarser.

    The field is zero-padded up to the next multiple of ``coarseness`` along
    each axis and every ``coarseness x coarseness`` block is averaged. With
    ``valid_only=True`` the padding is excluded from the average, so edge
    blocks hold the mean of their real cells instead of being pulled toward
    zero (required for coordinates such as lat/lon). Interior blocks are
    identical either way.
    """
    shape = np.array(np.shape(field), dtype=int)
    new_shape = coarseness * np.ceil(shape / coarseness).astype(int)
    block_shape = (new_shape[0] // coarseness, coarseness,
                   new_shape[1] // coarseness, coarseness)
    padded = np.zeros(new_shape)
    padded[:shape[0], :shape[1]] = field
    blocks = padded.reshape(block_shape)
    if not valid_only:
        return np.mean(blocks, axis=(1, 3))
    valid = np.zeros(new_shape)
    valid[:shape[0], :shape[1]] = 1.
    counts = valid.reshape(block_shape).sum(axis=(1, 3))
    return blocks.sum(axis=(1, 3)) / counts


def _coarse_xy(lat, lon, ref_lat, ref_lon, grid_dx):
    """Return wrf-python's (x, y) coarse-grid indices for a lat/lon point."""
    return wrf.ll_to_xy_proj(lat, lon,
                             map_proj=1,
                             truelat1=TRUELAT,
                             truelat2=TRUELAT,
                             stand_lon=STAND_LON,
                             ref_lat=ref_lat,
                             ref_lon=ref_lon,
                             known_x=0,
                             known_y=0,
                             dx=grid_dx,
                             dy=grid_dx)


def _interp_reversed(field, proj, ll_pt, start_point, end_point):
    """Interpolate ``field`` from start to end point, then reverse the result.

    The start point is the NE end, so after reversal SW is on the left
    and NE is on the right.
    """
    line = wrf.interpline(field, projection=proj, ll_point=ll_pt,
                          start_point=start_point, end_point=end_point,
                          latlon=True, meta=True)
    return line[::-1]


def _plot_plan_view(ax, hflux, lfn, lon2d, lat2d, fuel_cat, map_extent,
                    tiles, counties, lon_sections, lat_sections, label,
                    show_lat_labels):
    """Draw one plan-view panel and return its heat-flux pcolormesh.

    The panel shows heat flux in kW/m^2, with values <= 1 in the input
    units masked, over a shaded-relief basemap. It also draws the LFN = 0
    contour, the first and last cross-section transects, and county
    outlines. It prints the maximum heat flux over cells of fuel
    category 165 (TU5).
    """
    pc = ccrs.PlateCarree()
    ax.set_extent(map_extent, crs=pc)
    ax.add_image(tiles, 12)  # basemap tiles at zoom level 12

    fire = np.ma.masked_array(hflux, hflux <= 1)  # mask small values
    flux_plot = ax.pcolormesh(lon2d, lat2d, fire/1000., vmin=1, vmax=35,
                              cmap=plt.get_cmap('hot', 20), transform=pc,
                              zorder=11)

    tu5 = np.where(fuel_cat == 165)
    newFire = fire[tu5[0], tu5[1]]
    print(label+' Max TU5 Heat Flux: '+str(np.ma.max(newFire/1000.)))

    # Fire perimeter: zero contour of the level-set function. `levels` must
    # be a list (an int means "number of levels"), and the width keyword
    # for contour is `linewidths`.
    ax.contour(lon2d, lat2d, lfn, levels=[0], colors='k', linewidths=2,
               transform=pc, zorder=12)

    for jjj in (0, -1):  # outermost transects of the averaged set
        ax.plot(lon_sections[jjj, :], lat_sections[jjj, :], linewidth=5,
                color='white', alpha=.5, zorder=14, transform=pc)
        ax.plot(lon_sections[jjj, :], lat_sections[jjj, :], linewidth=3,
                color='k', linestyle='--', zorder=14, transform=pc)

    ax.set_xticks(np.linspace(map_extent[0]-.03, map_extent[1]+.03, 5)[1:-1], crs=pc)
    ax.set_yticks(np.linspace(map_extent[2]-.03, map_extent[3]+.03, 5)[1:-1], crs=pc)
    ax.xaxis.set_major_formatter(LongitudeFormatter(number_format='0.2f',
                                                    degree_symbol=u'\N{DEGREE SIGN}',
                                                    dateline_direction_label=True))
    ax.yaxis.set_major_formatter(LatitudeFormatter(number_format='0.2f',
                                                   degree_symbol=u'\N{DEGREE SIGN}'))
    ax.xaxis.set_tick_params(width=2, length=5, direction='in', labelsize=11,
                             rotation=20, bottom=True, top=True, zorder=99)
    if show_lat_labels:
        ax.yaxis.set_tick_params(width=2, length=5, direction='in', labelsize=11,
                                 rotation=20, left=True, right=True, zorder=99)
    else:
        # Keep the tick marks but hide the labels (latitude shown on panel 1).
        ax.yaxis.set_tick_params(width=2, length=5, direction='in', labelsize=0,
                                 rotation=20, left=True, right=True,
                                 labelcolor='white', zorder=99)

    ax.add_feature(cfeature.LAND, facecolor='k', alpha=.2, zorder=10)
    ax.add_feature(counties, facecolor='none', edgecolor='k', linewidth=1,
                   alpha=.4, zorder=13)

    # `outline_patch` was removed in cartopy 0.20; `spines['geo']` replaces it.
    outline = ax.spines['geo'] if 'geo' in ax.spines else ax.outline_patch
    outline.set_linewidth(2)
    outline.set_zorder(99)
    return flux_plot


time_arr = []
for fname in filelist:

    plot_hflux = []
    plot_lfn = []

    plot_xarrays = []
    plot_fireCross = []
    plot_fireWidth = []
    plot_fireWidthYloc = []

    for ff in fuels:

        # Read everything needed, then let the context manager close the file.
        with Dataset(filepath+'/'+ff+'/'+fname, mode='r') as wrf_file:
            # Create timestamp and time array
            times = wrf.extract_times(wrf_file, wrf.ALL_TIMES)
            tstamp = str(pd.Timestamp(times[0]))
            if ff == '_FUELx8':
                time_arr.append(tstamp)
            print('\nOpening '+tstamp+'...', end="", flush=True)

            # Fire grid variables
            fgrnhfx = wrf.getvar(wrf_file, "FGRNHFX", timeidx=-1)  # heat flux, ground fire
            fcanhfx = wrf.getvar(wrf_file, "FCANHFX", timeidx=-1)  # heat flux, crown fire
            lfn1 = wrf.getvar(wrf_file, 'LFN', timeidx=-1)
            fuel = wrf.getvar(wrf_file, 'NFUEL_CAT', timeidx=-1)
            # lat/lons of the fire mesh
            flat = wrf.getvar(wrf_file, 'FXLAT', timeidx=-1)
            flon = wrf.getvar(wrf_file, 'FXLONG', timeidx=-1)

            # subgrid ratio: fire-mesh cells per staggered atmospheric cell
            sr = (int(wrf_file.dimensions['west_east_subgrid'].size)
                  / int(wrf_file.dimensions['west_east_stag'].size))
            dx = int(wrf_file.DX)/sr

        f_hfx = fgrnhfx + fcanhfx  # total fire heat flux
        f_hfx = np.ma.masked_array(f_hfx, f_hfx <= 1)

        # removing the relaxation zones of the level-set function
        n_relax = int(sr)
        lfn1 = relax_zone_remover(lfn1, n_relax)
        f_hfx = relax_zone_remover(f_hfx, n_relax)
        fuel = relax_zone_remover(fuel, n_relax)
        flon = relax_zone_remover(flon, n_relax)
        flat = relax_zone_remover(flat, n_relax)

        # REGRIDDING to ~1 km blocks (GOES pixel size); at least one cell per block.
        coarseness = max(1, int(1000./dx))
        coarse_f_hfx = np.ma.masked_less_equal(_coarsen(f_hfx, coarseness), 0)
        coarse_flon = _coarsen(flon, coarseness, valid_only=True)
        coarse_flat = _coarsen(flat, coarseness, valid_only=True)

        print('Interp x-sections...', end="", flush=True)

        # Coarse-grid projection anchored at the lower-left coarse cell; the
        # same for every transect, so built once per scenario.
        coarse_dx = dx*coarseness
        ll_pt = wrf.CoordPair(lat=coarse_flat[0, 0], lon=coarse_flon[0, 0])
        proj = wrf.LambertConformal(map_proj=1,
                                    TRUELAT1=TRUELAT,
                                    TRUELAT2=TRUELAT,
                                    MOAD_CEN_LAT=39.71476,
                                    STAND_LON=STAND_LON,
                                    POLE_LAT=90.,
                                    POLE_LON=0.,
                                    dx=coarse_dx,
                                    dy=coarse_dx)
        (start_lat, start_lon), (end_lat, end_lon) = XSECT_ENDPOINTS[keyword]
        start_pt = _coarse_xy(start_lat, start_lon, coarse_flat[0, 0],
                              coarse_flon[0, 0], coarse_dx)
        end_pt = _coarse_xy(end_lat, end_lon, coarse_flat[0, 0],
                            coarse_flon[0, 0], coarse_dx)

        sections_lon, sections_lat, sections_fire = [], [], []
        for ll, (x_off, y_off) in enumerate(zip(x_idx, y_idx)):
            # NOTE(review): ll_to_xy_proj returns (x, y) = (west_east,
            # south_north), but 2-D WRF fields are indexed [south_north,
            # west_east]. These [x, y] lookups are kept from the original and
            # may be transposed — confirm the transect location on the map.
            s0, s1 = start_pt.data[0] + x_off, start_pt.data[1] + y_off
            e0, e1 = end_pt.data[0] + x_off, end_pt.data[1] + y_off

            # Define the cross section start and end points (NE to SW)
            start_point = wrf.CoordPair(lat=coarse_flat[s0, s1], lon=coarse_flon[s0, s1])
            end_point = wrf.CoordPair(lat=coarse_flat[e0, e1], lon=coarse_flon[e0, e1])

            sections_fire.append(_interp_reversed(coarse_f_hfx, proj, ll_pt, start_point, end_point))
            sections_lon.append(_interp_reversed(coarse_flon, proj, ll_pt, start_point, end_point))
            sections_lat.append(_interp_reversed(coarse_flat, proj, ll_pt, start_point, end_point))
            print(str(ll)+'.', end="", flush=True)

        # Transects can differ in length; trim every one to the shortest so
        # they stack, then average across transects.
        minlength = min(len(sec) for sec in sections_lon)
        _lon_cross = np.stack([sec[:minlength] for sec in sections_lon], axis=0)
        _lat_cross = np.stack([sec[:minlength] for sec in sections_lat], axis=0)
        _fire_cross = np.stack([sec[:minlength] for sec in sections_fire], axis=0)

        lon_cross = np.nanmean(_lon_cross, axis=0).squeeze()
        lat_cross = np.nanmean(_lat_cross, axis=0).squeeze()
        fire_cross = np.nanmean(_fire_cross, axis=0).squeeze()

        # Every scenario is normalized by the _FUELx8 peak (first in `fuels`).
        if ff == '_FUELx8':
            norm_max = np.nanmax(fire_cross)
        norm_fire_cross = fire_cross/norm_max

        # Fire front = peak of the profile. nanargmax skips NaN gaps;
        # np.argmax would return the position of the first NaN instead.
        if np.isnan(norm_fire_cross).all():
            fire_front_idx = 0
        else:
            fire_front_idx = int(np.nanargmax(norm_fire_cross))

        # Geodesic distance (km) of every point from the fire front; points
        # before the front index get negative distances.
        n_pts = len(lon_cross)
        _, _, dist_m = geod.inv(np.full(n_pts, lon_cross[fire_front_idx]),
                                np.full(n_pts, lat_cross[fire_front_idx]),
                                lon_cross, lat_cross)
        dist_m = np.asarray(dist_m)
        x_dist_array = np.where(np.arange(n_pts) < fire_front_idx,
                                -dist_m, dist_m) / 1000.

        # increase spatial resolution of points
        interpFunc = interpolate.interp1d(x_dist_array, norm_fire_cross)
        x_dist_array_interp = np.linspace(x_dist_array[0], x_dist_array[-1], 400)
        norm_fire_cross_interp = interpFunc(x_dist_array_interp)

        # Fire-line width: span where the profile exceeds ~1/e (0.36) of its peak.
        loc_width = .36*np.nanmax(norm_fire_cross_interp)
        fire_width = x_dist_array_interp[norm_fire_cross_interp > loc_width]
        yloc = np.zeros(np.shape(fire_width)) + loc_width

        # Plan view data
        plot_hflux.append(f_hfx)
        plot_lfn.append(lfn1)

        # Cross section data
        plot_xarrays.append(x_dist_array_interp)
        plot_fireCross.append(norm_fire_cross_interp)
        plot_fireWidth.append(fire_width)
        plot_fireWidthYloc.append(yloc)

    #########################
    # Plotting
    #########################
    print('\nPlotting...')
    fig = plt.figure(1, figsize=(10, 5))
    spec = gridspec.GridSpec(nrows=40, ncols=85, figure=fig)

    # Plan views, left to right: Fuelx1, Fuelx4, Fuelx8 (indices 2, 1, 0 of
    # `fuels`). `fuel`, `flon` and `flat` come from the last scenario read;
    # the fuel map and fire mesh are assumed identical across scenarios.
    panels = [(spec[:19, :25], 2, 'Fuelx1', True),
              (spec[:19, 27:52], 1, 'Fuelx4', False),
              (spec[:19, 54:79], 0, 'Fuelx8', False)]
    plan_axes = []
    for cell, k, label, lat_labels in panels:
        ax = fig.add_subplot(cell, projection=ccrs.PlateCarree())
        flux_plot = _plot_plan_view(ax, plot_hflux[k], plot_lfn[k], flon, flat,
                                    fuel, extent, esri_terrain, COUNTIES,
                                    _lon_cross, _lat_cross, label, lat_labels)
        plan_axes.append(ax)

    # Shared colorbar (all panels use the same cmap and vmin/vmax).
    ax5 = fig.add_subplot(spec[1:18, 82:84])
    cbar = fig.colorbar(flux_plot, cax=ax5)
    cbar.set_label('kW/m^2', fontsize=12, fontweight='heavy')

    # Title over the middle panel
    plan_axes[1].set_title(tstamp+' UTC\n', loc='center', fontsize=12,
                           fontweight='bold', zorder=10)

    #### cross section ####
    ax4 = fig.add_subplot(spec[25:, 10:70])  # rows, cols

    linecolors = ['maroon', 'mediumvioletred', 'dodgerblue']
    scenarios = ['Fuelx8', 'Fuelx4', 'Fuelx1']

    for xx in reversed(range(len(plot_xarrays))):
        ax4.plot(plot_xarrays[xx], plot_fireCross[xx], color=linecolors[xx], linewidth=3)
        ax4.plot(plot_fireWidth[xx], plot_fireWidthYloc[xx], color=linecolors[xx],
                 linestyle='--', linewidth=3)
        # print stats
        if len(plot_fireWidth[xx]) > 0:
            print(scenarios[xx]+' Fire Width: '
                  + str(np.abs(plot_fireWidth[xx][0]-plot_fireWidth[xx][-1])))

    # Set x/y lims
    ax4.set_xlim(-13, 21)
    ax4.set_ylim(0, 1)
    ax4.set_xticks(np.arange(-12, 22, 2))
    ax4.grid(True)

    ax4.set_xlabel(r'Fire front distance [km]', fontsize=12)
    ax4.set_ylabel(r'Normalized $H_s$', fontsize=12)

    # Set border and ticks around plot
    for axis in ['top', 'bottom', 'left', 'right']:
        ax4.spines[axis].set_linewidth(2)
        ax4.spines[axis].set_zorder(99)

    ax4.xaxis.set_tick_params(width=2, length=5, direction="in", labelsize=12,
                              zorder=99, bottom=True, top=True)
    ax4.yaxis.set_tick_params(width=2, length=5, direction="in", labelsize=12,
                              zorder=99, left=True, right=True)

    # save
    savetime = pd.Timestamp(tstamp).strftime("%Y%m%d_%H%M")
    os.makedirs(savepath, exist_ok=True)
    plt.savefig(savepath+'/'+keyword+'_flameDepth_'+savetime+'.png',
                bbox_inches='tight', dpi=260)
    plt.close('all')
    print('Saved '+keyword+'_flameDepth_'+savetime+'.png')

# np.save(mainpath+'/data/'+keyword+'_normMaxHs.npy', np.asarray(norm_max))

# Report total wall-clock run time since start_time was recorded.
end_time = dt.datetime.now()
elapsed = end_time - start_time
print('===================')
print(f'Total elapsed time: {elapsed}\n')